{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Copyright 2020 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#     https://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<table align=\"left\">\n",
    "  <td>\n",
    "    <a href=\"https://colab.research.google.com/github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/notebooks/samples/pytorch/text_classification_using_pytorch_and_ai_platform.ipynb\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/colab-logo-32px.png\" alt=\"Colab logo\"> Run in Colab\n",
    "    </a>\n",
    "  </td>\n",
    "  <td>\n",
    "    <a href=\"https://github.com/GoogleCloudPlatform/ai-platform-samples/blob/master/notebooks/samples/pytorch/text_classification_using_pytorch_and_ai_platform.ipynb\">\n",
    "      <img src=\"https://cloud.google.com/ml-engine/images/github-logo-32px.png\" alt=\"GitHub logo\">\n",
    "      View on GitHub\n",
    "    </a>\n",
    "  </td>\n",
    "</table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "UWvrShlZjZwr"
   },
   "source": [
    "## Overview\n",
    "\n",
    "This notebook illustrates the new feature of serving custom model prediction code on AI Platform. It allows us to execute arbitrary python pre-processing code prior to invoking a model, as well as post-processing on the produced predictions. In addition, you can use a model build by your **favourite Python-based ML framework**!\n",
    "\n",
    "This is all done server-side so that the client can pass data directly to AI Platform Serving in the unprocessed state.\n",
    "\n",
    "We will take advantage of this for text classification because it involves pre-processing that is not easily accomplished using native TensorFlow. Instead we will execute the the non TensorFlow pre-processing via python code on the server side.\n",
    "\n",
    "We will build a text classification model using [PyTorch](https://pytorch.org), while performing text preproessing using Keras. PyTorch is an open source deep learning platform that provides a seamless path from research prototyping to production deployment.\n",
    "\n",
    "\n",
    "## Dataset\n",
    "[Hacker News](https://bigquery.cloud.google.com/table/fh-bigquery:hackernews.stories) is one of many public datasets available in [BigQuery](https://cloud.google.com/bigquery). This dataset includes titles of articles from several data sources. For the following tutorial, we extracted the titles that belong to either GitHub, The New York Times, or TechCrunch, and saved them as CSV files in a publicly shared Cloud Storage bucket at the following location: **gs://cloud-training-demos/blogs/CMLE_custom_prediction**\n",
    "\n",
    "## Objective\n",
    "The goal of this tutorial is to:\n",
    "1. Process the data for text classification.\n",
    "2. Train a [PyTorch](https://pytorch.org) Text Classifier (locally).\n",
    "3. Deploy the [PyTorch](https://pytorch.org) Text Classifier, along with the preprocessing artifacts, to AI Platform Serving, using the Custom Online Prediction code.\n",
    "\n",
    "This tutorial focuses more on using this model with AI Platform Serving than on the design of the text classification model itself. For more details about text classification, please refer to [Google developer's Guide to Text Classification](https://developers.google.com/machine-learning/guides/text-classification/). \n",
    "\n",
    "### Costs\n",
    "\n",
    "This tutorial uses billable components of Google Cloud Platform (GCP):\n",
    "\n",
    "* AI Platform\n",
    "* Cloud Storage\n",
    "\n",
    "Learn about [AI Platform\n",
    "pricing](https://cloud.google.com/ml-engine/docs/pricing) and [Cloud Storage\n",
    "pricing](https://cloud.google.com/storage/pricing), and use the [Pricing\n",
    "Calculator](https://cloud.google.com/products/calculator/)\n",
    "to generate a cost estimate based on your projected usage."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Authenticate your GCP account\n",
    "\n",
    "**If you are using AI Platform Notebooks**, your environment is already\n",
    "authenticated. Skip this step."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**If you are using Colab**, run the cell below and follow the instructions\n",
    "when prompted to authenticate your account via oAuth.\n",
    "\n",
    "**Otherwise**, follow these steps:\n",
    "\n",
    "1. In the GCP Console, go to the [**Create service account key**\n",
    "   page](https://console.cloud.google.com/apis/credentials/serviceaccountkey).\n",
    "\n",
    "2. From the **Service account** drop-down list, select **New service account**.\n",
    "\n",
    "3. In the **Service account name** field, enter a name.\n",
    "\n",
    "4. From the **Role** drop-down list, select\n",
    "   **Machine Learning Engine > AI Platform Admin** and\n",
    "   **Storage > Storage Object Admin**.\n",
    "\n",
    "5. Click *Create*. A JSON file that contains your key downloads to your\n",
    "local environment.\n",
    "\n",
    "6. Enter the path to your service account key as the\n",
    "`GOOGLE_APPLICATION_CREDENTIALS` variable in the cell below and run the cell."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "66JlKmfzvPhN",
    "tags": [
     "no_execute"
    ]
   },
   "outputs": [],
   "source": [
    "import sys\n",
    "\n",
    "# If you are running this notebook in Colab, run this cell and follow the\n",
    "# instructions to authenticate your GCP account. This provides access to your\n",
    "# Cloud Storage bucket and lets you submit training jobs and prediction\n",
    "# requests.\n",
    "\n",
    "if 'google.colab' in sys.modules:\n",
    "  from google.colab import auth as google_auth\n",
    "  google_auth.authenticate_user()\n",
    "\n",
    "# If you are running this notebook locally, replace the string below with the\n",
    "# path to your service account key and run this cell to authenticate your GCP\n",
    "# account.\n",
    "else:\n",
    "  %env GOOGLE_APPLICATION_CREDENTIALS ''"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eQWHJ3X7vLeQ"
   },
   "source": [
    "Run the following cell to install Python dependencies needed to train the model locally. When you run the training job in AI Platform,\n",
    "dependencies are preinstalled based on the [runtime\n",
    "version](https://cloud.google.com/ml-engine/docs/tensorflow/runtime-version-list)\n",
    "you choose."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "PEuI7Hw5PquP",
    "tags": [
     "no_execute"
    ]
   },
   "outputs": [],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 649
    },
    "colab_type": "code",
    "id": "r733GnVjSwgp",
    "outputId": "1c6b2d77-2534-4d02-f56f-d4a6620a4a09"
   },
   "outputs": [],
   "source": [
    "!pip install tensorflow==1.15.2 --user\n",
    "!pip install torch --user"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "colab_type": "code",
    "id": "bupGswGOcMW2",
    "outputId": "aba75458-3190-4d34-f9f5-b8a2f9ac3bb2"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import torch\n",
    "import os\n",
    "\n",
    "print(tf.__version__) \n",
    "print(torch.__version__) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Set up your GCP project\n",
    "\n",
    "**The following steps are required, regardless of your notebook environment.**\n",
    "\n",
    "1. [Select or create a GCP project.](https://console.cloud.google.com/cloud-resource-manager)\n",
    "\n",
    "2. [Make sure that billing is enabled for your project.](https://cloud.google.com/billing/docs/how-to/modify-project)\n",
    "\n",
    "3. [Enable the AI Platform (\"Cloud Machine Learning Engine\") and Compute Engine APIs.](https://console.cloud.google.com/flows/enableapi?apiid=ml.googleapis.com,compute_component)\n",
    "\n",
    "4. Enter your project ID in the cell below. Then run the  cell to make sure the\n",
    "Cloud SDK uses the right project for all the commands in this notebook.\n",
    "\n",
    "**Note**: Jupyter runs lines prefixed with `!` as shell commands, and it interpolates Python variables prefixed with `$` into these commands."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ewdEsJuVvNqm"
   },
   "outputs": [],
   "source": [
    "PROJECT_ID = '[your-project-id]' # TODO (Set up your GCP Project name)\n",
    "!gcloud config set project {PROJECT_ID}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create a Cloud Storage bucket\n",
    "\n",
    "**The following steps are required, regardless of your notebook environment.**\n",
    "\n",
    "When you submit a training job using the Cloud SDK, you upload a Python package\n",
    "containing your training code to a Cloud Storage bucket. AI Platform runs\n",
    "the code from this package. In this tutorial, AI Platform also saves the\n",
    "trained model that results from your job in the same bucket. You can then\n",
    "create an AI Platform model version based on this output in order to serve\n",
    "online predictions.\n",
    "\n",
    "Set the name of your Cloud Storage bucket below. It must be unique across all\n",
    "Cloud Storage buckets. \n",
    "\n",
    "You may also change the `REGION` variable, which is used for operations\n",
    "throughout the rest of this notebook. Make sure to [choose a region where Cloud\n",
    "AI Platform services are\n",
    "available](https://cloud.google.com/ml-engine/docs/tensorflow/regions)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "BUCKET_NAME = '[your-bucket-name]' #@param {type:\"string\"}\n",
    "REGION = 'us-central1' #@param {type:\"string\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "ROOT='torch_text_classification'\n",
    "MODEL_DIR=os.path.join(ROOT,'models')\n",
    "PACKAGES_DIR=os.path.join(ROOT,'packages')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "Qe49zwLLSXzu",
    "outputId": "58d8081d-b58e-494b-e504-065a768a9a36"
   },
   "outputs": [],
   "source": [
    "# Delete any previous artifacts from Google Cloud Storage\n",
    "!gsutil rm -r gs://{BUCKET_NAME}/{ROOT}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "V12s5BFCkYhC"
   },
   "source": [
    "## Download and Explore Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "id": "ww6dAQ7EkHzZ",
    "outputId": "af58faac-c405-4e14-a644-a7f54dfa08c9"
   },
   "outputs": [],
   "source": [
    "%%bash\n",
    "gsutil cp gs://cloud-training-demos/blogs/CMLE_custom_prediction/keras_text_pre_processing/train.tsv .\n",
    "gsutil cp gs://cloud-training-demos/blogs/CMLE_custom_prediction/keras_text_pre_processing/eval.tsv ."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "g1XEyk4toH76"
   },
   "outputs": [],
   "source": [
    "!head eval.tsv"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "bHqYoDlFsw02"
   },
   "source": [
    "## Preprocessing"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "yIarI3u3r6vd"
   },
   "source": [
    "### Pre-processing class to be used in both training and serving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "GdkNipQtr8-P",
    "outputId": "8dd22601-e4f6-4a2d-9be6-22f7fe94c493"
   },
   "outputs": [],
   "source": [
    "%%writefile preprocess.py\n",
    "\n",
    "from tensorflow.python.keras.preprocessing import sequence\n",
    "from tensorflow.keras.preprocessing import text\n",
    "\n",
    "\n",
    "class TextPreprocessor(object):\n",
    "    def __init__(self, vocab_size, max_sequence_length):\n",
    "        self._vocabb_size = vocab_size\n",
    "        self._max_sequence_length = max_sequence_length\n",
    "        self._tokenizer = None\n",
    "\n",
    "    def fit(self, text_list):        \n",
    "        # Create vocabulary from input corpus.\n",
    "        tokenizer = text.Tokenizer(num_words=self._vocabb_size)\n",
    "        tokenizer.fit_on_texts(text_list)\n",
    "        self._tokenizer = tokenizer\n",
    "\n",
    "    def transform(self, text_list):        \n",
    "        # Transform text to sequence of integers\n",
    "        text_sequence = self._tokenizer.texts_to_sequences(text_list)\n",
    "        # Fix sequence length to max value. Sequences shorter than the length are\n",
    "        # padded in the beginning and sequences longer are truncated\n",
    "        # at the beginning.\n",
    "        padded_text_sequence = sequence.pad_sequences(\n",
    "          text_sequence, maxlen=self._max_sequence_length)\n",
    "        return padded_text_sequence"
   ]
  },
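  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the padding rule in `transform` concrete, here is a minimal pure-Python sketch of pre-padding and pre-truncation, which is what `pad_sequences` does with its default arguments. The helper name `pad_or_truncate` is hypothetical and only serves as an illustration; it is not part of Keras or of this tutorial's package.\n",
    "\n",
    "```python\n",
    "def pad_or_truncate(seq, maxlen, pad_value=0):\n",
    "    # Assumes maxlen >= 1. Keep only the last `maxlen` items, which\n",
    "    # truncates long sequences at the beginning.\n",
    "    trimmed = seq[-maxlen:]\n",
    "    # Left-pad with `pad_value` so every row has exactly `maxlen` items.\n",
    "    return [pad_value] * (maxlen - len(trimmed)) + trimmed\n",
    "\n",
    "\n",
    "short_row = pad_or_truncate([1, 2, 3], maxlen=5)\n",
    "long_row = pad_or_truncate([1, 2, 3, 4, 5, 6, 7], maxlen=5)\n",
    "print(short_row)  # [0, 0, 1, 2, 3]\n",
    "print(long_row)   # [3, 4, 5, 6, 7]\n",
    "```\n",
    "\n",
    "Because index 0 is used for padding, the tokenizer's word indices start at 1."
   ]
  },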
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pEk_bHo08a1n"
   },
   "source": [
    "### Test Prepocessing Locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "sKtHuVUw8kFO",
    "outputId": "0d881cc9-b898-42a6-b5c9-6faaf79d4da4"
   },
   "outputs": [],
   "source": [
    "from preprocess import TextPreprocessor\n",
    "\n",
    "processor = TextPreprocessor(5, 5)\n",
    "processor.fit(['hello machine learning'])\n",
    "processor.transform(['hello machine learning'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lsTGmBn4s2fC"
   },
   "source": [
    "## Model Creation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "eFKEHCk_WYQU"
   },
   "source": [
    "### Metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "P-1_lAwsWNFf"
   },
   "outputs": [],
   "source": [
    "CLASSES = {'github': 0, 'nytimes': 1, 'techcrunch': 2}  # label-to-int mapping\n",
    "NUM_CLASSES = 3\n",
    "VOCAB_SIZE = 20000  # Limit on the number vocabulary size used for tokenization\n",
    "MAX_SEQUENCE_LENGTH = 50  # Sentences will be truncated/padded to this length"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "jkzmfu3jV-78"
   },
   "source": [
    "### Prepare data for training and evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "HDX_3jquWCBY"
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "from preprocess import TextPreprocessor\n",
    "\n",
    "def load_data(train_data_path, eval_data_path):\n",
    "    # Parse CSV using pandas\n",
    "    column_names = ('label', 'text')\n",
    "    \n",
    "    df_train = pd.read_csv(train_data_path, names=column_names, sep='\\t')\n",
    "    df_train = df_train.sample(frac=1)\n",
    "    \n",
    "    df_eval = pd.read_csv(eval_data_path, names=column_names, sep='\\t')\n",
    "\n",
    "    return ((list(df_train['text']), np.array(df_train['label'].map(CLASSES))),\n",
    "            (list(df_eval['text']), np.array(df_eval['label'].map(CLASSES))))\n",
    "\n",
    "\n",
    "((train_texts, train_labels), (eval_texts, eval_labels)) = load_data(\n",
    "       'train.tsv', 'eval.tsv')\n",
    "\n",
    "# Create vocabulary from training corpus.\n",
    "processor = TextPreprocessor(VOCAB_SIZE, MAX_SEQUENCE_LENGTH)\n",
    "processor.fit(train_texts)\n",
    "\n",
    "# Preprocess the data\n",
    "train_texts_vectorized = processor.transform(train_texts)\n",
    "eval_texts_vectorized = processor.transform(eval_texts)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "r9sDfoXesC1d"
   },
   "source": [
    "### Build the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "E7hifM2esEe9",
    "outputId": "ca199fd7-bb34-4cee-c1fa-c69bbc663380"
   },
   "outputs": [],
   "source": [
    "%%writefile torch_model.py\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class TorchTextClassifier(nn.Module):\n",
    "    \n",
    "    def __init__(self, vocab_size, embedding_dim, seq_length, num_classes, \n",
    "                 num_filters, kernel_size, pool_size, dropout_rate):\n",
    "        super(TorchTextClassifier, self).__init__()\n",
    "\n",
    "        self.embeddings = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_dim)\n",
    "        \n",
    "        self.conv1 = nn.Conv1d(seq_length, num_filters, kernel_size)\n",
    "        self.max_pool1 = nn.MaxPool1d(pool_size)\n",
    "        self.conv2 = nn.Conv1d(num_filters, num_filters*2, kernel_size)\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout_rate)\n",
    "        self.dense = nn.Linear(num_filters*2, num_classes)\n",
    "        \n",
    "\n",
    "    def forward(self, x):\n",
    "        \n",
    "        x = self.embeddings(x)\n",
    "        x = self.dropout(x)\n",
    "\n",
    "        x = self.conv1(x)\n",
    "        x = F.relu(x)\n",
    "        x = self.max_pool1(x)\n",
    "        \n",
    "        x = self.conv2(x)\n",
    "        x = F.relu(x)\n",
    "        x = F.max_pool1d(x, x.size()[2]).squeeze(2)\n",
    "        \n",
    "        x = self.dropout(x)\n",
    "        x = self.dense(x)\n",
    "        x = F.softmax(x, 1)\n",
    "        return x"
   ]
  },
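  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "It can help to trace tensor shapes through `forward`. The embedding layer returns a tensor of shape `(batch, seq_length, embedding_dim)`, and `nn.Conv1d` treats the second dimension as channels, so `conv1` receives `seq_length` input channels and slides along the embedding dimension. The sketch below computes the resulting lengths with the standard formulas for stride-1, unpadded convolution and for max pooling whose stride equals the pool size. The helper names are hypothetical, and the numeric values are assumptions that mirror the hyperparameters used in the training cell.\n",
    "\n",
    "```python\n",
    "def conv1d_out_len(length, kernel_size):\n",
    "    # Output length of a 1-D convolution with stride 1 and no padding.\n",
    "    return length - kernel_size + 1\n",
    "\n",
    "\n",
    "def max_pool1d_out_len(length, pool_size):\n",
    "    # Output length of max pooling whose stride equals the pool size.\n",
    "    return (length - pool_size) // pool_size + 1\n",
    "\n",
    "\n",
    "# Assumed values, mirroring the hyperparameters in the training cell.\n",
    "embedding_dim, kernel_size, pool_size = 200, 3, 3\n",
    "\n",
    "after_conv1 = conv1d_out_len(embedding_dim, kernel_size)\n",
    "after_pool1 = max_pool1d_out_len(after_conv1, pool_size)\n",
    "after_conv2 = conv1d_out_len(after_pool1, kernel_size)\n",
    "print(after_conv1, after_pool1, after_conv2)  # 198 66 64\n",
    "```\n",
    "\n",
    "The final `F.max_pool1d` call pools over the whole remaining length, which leaves one value per filter for the dense layer."
   ]
  },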
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "53Gef0A4sjdV"
   },
   "source": [
    "### Train and save the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 391
    },
    "colab_type": "code",
    "id": "PCkmyHfASX0g",
    "outputId": "354cb0fc-bbcc-452a-cf50-17bf0fd5b0b4"
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.autograd import Variable\n",
    "import torch.nn.functional as F\n",
    "\n",
    "LEARNING_RATE=.001\n",
    "FILTERS=64\n",
    "DROPOUT_RATE=0.2\n",
    "EMBEDDING_DIM=200\n",
    "KERNEL_SIZE=3\n",
    "POOL_SIZE=3\n",
    "\n",
    "NUM_EPOCH=1\n",
    "BATCH_SIZE=128\n",
    "\n",
    "train_size = len(train_texts)\n",
    "steps_per_epoch = int(len(train_labels)/BATCH_SIZE)\n",
    "\n",
    "print(\"Train size: {}\".format(train_size))\n",
    "print(\"Batch size: {}\".format(BATCH_SIZE))\n",
    "print(\"Number of epochs: {}\".format(NUM_EPOCH))\n",
    "print(\"Steps per epoch: {}\".format(steps_per_epoch))\n",
    "print(\"Vocab Size: {}\".format(VOCAB_SIZE))\n",
    "print(\"Embed Dimensions: {}\".format(EMBEDDING_DIM))\n",
    "print(\"Sequence Length: {}\".format(MAX_SEQUENCE_LENGTH))\n",
    "print(\"\")\n",
    "\n",
    "\n",
    "def get_batch(step):\n",
    "    start_index = step*BATCH_SIZE\n",
    "    end_index = start_index + BATCH_SIZE\n",
    "    x = Variable(torch.Tensor(train_texts_vectorized[start_index:end_index]).long())\n",
    "    y = Variable(torch.Tensor(train_labels[start_index:end_index]).long())\n",
    "    return x, y\n",
    "\n",
    "\n",
    "from torch_model import TorchTextClassifier\n",
    "\n",
    "model = TorchTextClassifier(VOCAB_SIZE, \n",
    "                            EMBEDDING_DIM, \n",
    "                            MAX_SEQUENCE_LENGTH, \n",
    "                            NUM_CLASSES, \n",
    "                            FILTERS, \n",
    "                            KERNEL_SIZE, \n",
    "                            POOL_SIZE, \n",
    "                            DROPOUT_RATE)\n",
    "\n",
    "model.train()\n",
    "loss_metric = F.cross_entropy\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)\n",
    "\n",
    "for epoch in range(NUM_EPOCH):\n",
    "    for step in range(steps_per_epoch):\n",
    "        x, y = get_batch(step)\n",
    "        optimizer.zero_grad()\n",
    "        y_pred = model(x)\n",
    "        loss = loss_metric(y_pred, y) \n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        if step % 50 == 0:\n",
    "            print('Batch [{}/{}] Loss: {}'.format(step+1, steps_per_epoch, round(loss.item(),5)))\n",
    "    print('Epoch [{}/{}] Loss: {}'.format(epoch+1, NUM_EPOCH, round(loss.item(),5)))\n",
    "print('Final Loss: {}'.format(epoch+1, NUM_EPOCH, round(loss.item(),5)))\n",
    "\n",
    "torch.save(model, 'torch_saved_model.pt')"
   ]
  },
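  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the loss: a function such as `F.cross_entropy` applies log-softmax to its input internally, so it expects raw logits, whereas `TorchTextClassifier.forward` ends with `F.softmax`. The pure-Python sketch below illustrates why probabilities should not be passed through log-softmax a second time: with three classes the loss can never fall below `log(1 + 2/e)`, however confident the model is. The logit values are hypothetical and chosen only for illustration.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "\n",
    "def softmax(values):\n",
    "    exps = [math.exp(v) for v in values]\n",
    "    total = sum(exps)\n",
    "    return [e / total for e in exps]\n",
    "\n",
    "\n",
    "def log_softmax_loss(scores, target):\n",
    "    # Negative log of the softmax probability assigned to the target class.\n",
    "    return -math.log(softmax(scores)[target])\n",
    "\n",
    "\n",
    "logits = [10.0, 0.0, 0.0]  # hypothetical, very confident scores for class 0\n",
    "probs = softmax(logits)\n",
    "\n",
    "loss_from_logits = log_softmax_loss(logits, target=0)\n",
    "loss_from_probs = log_softmax_loss(probs, target=0)\n",
    "floor = math.log(1 + 2 / math.e)  # lowest loss reachable from probabilities\n",
    "print(loss_from_logits, loss_from_probs, floor)\n",
    "```\n",
    "\n",
    "The loss computed from logits is close to zero, while the loss computed from probabilities stays just above the floor of roughly 0.55."
   ]
  },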
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "MakbLwTbuMsZ"
   },
   "source": [
    "### Save pre-processing object\n",
    "\n",
    "We need to save this so the same tokenizer used at training can be used to pre-process during serving"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "sziwQgs0uZzx"
   },
   "outputs": [],
   "source": [
    "import pickle\n",
    "with open('./processor_state.pkl', 'wb') as f:\n",
    "    pickle.dump(processor, f)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4AWJZP3stCta"
   },
   "source": [
    "## Custom Model Prediction Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KeR_jDYjuymX"
   },
   "source": [
    "### Copy model and pre-processing object to GCS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 119
    },
    "colab_type": "code",
    "id": "OJwsuGK3ub4S",
    "outputId": "ec1facd0-b502-4fc6-d41a-52fffaf3bb56"
   },
   "outputs": [],
   "source": [
    "!gsutil cp torch_saved_model.pt gs://{BUCKET_NAME}/{MODEL_DIR}/\n",
    "!gsutil cp processor_state.pkl gs://{BUCKET_NAME}/{MODEL_DIR}/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "DZ0H1GKAueAp"
   },
   "source": [
    "### Define Model Class"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "xLvpSsMiufVr",
    "outputId": "4f9aee42-7c50-45f4-f667-72ee55f051e2"
   },
   "outputs": [],
   "source": [
    "%%writefile model.py\n",
    "\n",
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import torch\n",
    "from torch.autograd import Variable\n",
    "\n",
    "\n",
    "class CustomModelPrediction(object):\n",
    "    def __init__(self, model, processor):\n",
    "        self._model = model\n",
    "        self._processor = processor\n",
    "\n",
    "    def _postprocess(self, predictions):\n",
    "        labels = ['github', 'nytimes', 'techcrunch']\n",
    "        label_indexes = [np.argmax(prediction) \n",
    "                             for prediction in predictions.detach().numpy()]\n",
    "        return [labels[label_index] for label_index in label_indexes]\n",
    "\n",
    "\n",
    "    def predict(self, instances, **kwargs):\n",
    "        preprocessed_data = self._processor.transform(instances)\n",
    "        predictions =  self._model(Variable(torch.Tensor(preprocessed_data).long()))\n",
    "        labels = self._postprocess(predictions)\n",
    "        return labels\n",
    "\n",
    "\n",
    "    @classmethod\n",
    "    def from_path(cls, model_dir):\n",
    "        import torch \n",
    "        import torch_model\n",
    "        model = torch.load(os.path.join(model_dir,'torch_saved_model.pt'))\n",
    "        model.eval()\n",
    "        with open(os.path.join(model_dir, 'processor_state.pkl'), 'rb') as f:\n",
    "            processor = pickle.load(f)\n",
    "        return cls(model, processor)\n",
    "    "
   ]
  },
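  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The class above follows a simple contract: a `from_path` class method builds the predictor from a model directory, and `predict` maps a list of raw instances to a list of results. The toy sketch below follows the same contract without any ML framework, which can be handy for checking the flow before loading a real model. `ToyPredictor` and its keyword-based scorer are hypothetical stand-ins, not part of AI Platform or of this tutorial's package.\n",
    "\n",
    "```python\n",
    "class ToyPredictor(object):\n",
    "    # Hypothetical stand-in that follows the predict/from_path contract.\n",
    "\n",
    "    LABELS = ['github', 'nytimes', 'techcrunch']\n",
    "\n",
    "    def __init__(self, scorer):\n",
    "        self._scorer = scorer\n",
    "\n",
    "    def predict(self, instances, **kwargs):\n",
    "        # Score each raw instance, then map the arg-max index to its label.\n",
    "        results = []\n",
    "        for instance in instances:\n",
    "            scores = self._scorer(instance)\n",
    "            best = max(range(len(scores)), key=scores.__getitem__)\n",
    "            results.append(self.LABELS[best])\n",
    "        return results\n",
    "\n",
    "    @classmethod\n",
    "    def from_path(cls, model_dir):\n",
    "        # A real implementation would load artifacts from model_dir here.\n",
    "        def keyword_scorer(title):\n",
    "            return [float('show hn' in title.lower()), 0.5, 0.0]\n",
    "        return cls(keyword_scorer)\n",
    "\n",
    "\n",
    "predictor = ToyPredictor.from_path('.')\n",
    "labels = predictor.predict(['Show HN: a tiny tool', 'Some other headline'])\n",
    "print(labels)  # ['github', 'nytimes']\n",
    "```\n",
    "\n",
    "Returning plain Python lists of strings keeps the result easy to serialize as JSON in the prediction response."
   ]
  },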
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "kL-jv3GDg_zD"
   },
   "source": [
    "### Test Model Class Locally"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ynL_ovR32Od0"
   },
   "outputs": [],
   "source": [
    "# Headlines for Predictions\n",
    "\n",
    "techcrunch=[\n",
    "  'Uber shuts down self-driving trucks unit',\n",
    "  'Grover raises €37M Series A to offer latest tech products as a subscription',\n",
    "  'Tech companies can now bid on the Pentagon’s $10B cloud contract'\n",
    "]\n",
    "nytimes=[\n",
    "  '‘Lopping,’ ‘Tips’ and the ‘Z-List’: Bias Lawsuit Explores Harvard’s Admissions',\n",
    "  'A $3B Plan to Turn Hoover Dam into a Giant Battery',\n",
    "  'A MeToo Reckoning in China’s Workplace Amid Wave of Accusations'\n",
    "]\n",
    "github=[\n",
    "  'Show HN: Moon – 3kb JavaScript UI compiler',\n",
    "  'Show HN: Hello, a CLI tool for managing social media',\n",
    "  'Firefox Nightly added support for time-travel debugging'\n",
    "]\n",
    "requests = (techcrunch+nytimes+github)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 170
    },
    "colab_type": "code",
    "id": "zx53zlPK6YIK",
    "outputId": "ec8bcefd-786e-4fc4-d485-f5c0f25438b5"
   },
   "outputs": [],
   "source": [
    "from model import CustomModelPrediction\n",
    "\n",
    "local_prediction = CustomModelPrediction.from_path('.')\n",
    "local_prediction.predict(requests)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Q4BEQqoGyZP1"
   },
   "source": [
    "### Package up files and copy to GCS\n",
    "\n",
    "Create a setup.py script to bundle **model.py**,**preprocess.py** and **torch_model.py**  in a tarball package. Notice that setup.py does not include the dependencies of `model.py` in the package. These dependencies are provided to your model version in other ways:\n",
    "\n",
    "`numpy` and `google-cloud-storage` are both included as part of AI Platform Prediction runtime version 1.15.\n",
    "\n",
    "`torch` is provided in a separate package, as described in a following section.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "DgQ9UJG_u6Jk",
    "outputId": "f7fed653-37f5-4f0f-ce3b-23d92e0e6705"
   },
   "outputs": [],
   "source": [
    "%%writefile setup.py\n",
    "\n",
    "from setuptools import setup\n",
    "\n",
    "REQUIRED_PACKAGES = ['keras']\n",
    "\n",
    "setup(\n",
    "  name=\"text_classification\",\n",
    "  version=\"0.1\",\n",
    "  scripts=[\"preprocess.py\", \"model.py\", \"torch_model.py\"],\n",
    "  include_package_data=True,\n",
    "  install_requires=REQUIRED_PACKAGES\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 649
    },
    "colab_type": "code",
    "id": "IjwWftmpybCI",
    "outputId": "5ba3cbba-6110-464a-9844-7cd7578debc4"
   },
   "outputs": [],
   "source": [
    "!python setup.py sdist\n",
    "!gsutil cp ./dist/text_classification-0.1.tar.gz gs://{BUCKET_NAME}/{PACKAGES_DIR}/text_classification-0.1.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "pU4D8prVtLNI"
   },
   "source": [
    "## Model Deployment to AI Platform"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WUEA9FKcy8fM"
   },
   "outputs": [],
   "source": [
    "MODEL_NAME='torch_text_classification'\n",
    "MODEL_VERSION='v1'\n",
    "RUNTIME_VERSION='1.15'\n",
    "REGION='us-central1'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "YpZkN4o71PUY",
    "outputId": "4937b721-664c-42e9-e769-215ac7405432"
   },
   "outputs": [],
   "source": [
    "# Delete model version if any\n",
    "! gcloud ai-platform versions delete {MODEL_VERSION} --model {MODEL_NAME} --quiet # run if version already created\n",
    "\n",
    "# Delete model resource\n",
    "! gcloud ai-platform models delete {MODEL_NAME} --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 122
    },
    "colab_type": "code",
    "id": "JBao9v7e1OU0",
    "outputId": "af4b8219-f9c5-4f8f-d9a6-b0acf561f7de"
   },
   "outputs": [],
   "source": [
    "!gcloud beta ai-platform models create {MODEL_NAME} --regions {REGION} --enable-logging --enable-console-logging"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Pytorch compatible packages\n",
    "\n",
    "You need to specify two Python packages when you create your version resource. One of these is the package containing `model.py` that you uploaded to Cloud Storage in a previous step. The other is a package containing the version of PyTorch that you need.\n",
    "\n",
    "Google Cloud provides a collection of PyTorch packages in the `gs://cloud-ai-pytorch` Cloud Storage bucket. These packages are mirrored from the official builds.\n",
    "\n",
    "For this tutorial, use `gs://cloud-ai-pytorch/torch-1.3.1+cpu-cp37-cp37m-linux_x86_64.whl` as your PyTorch package. This provides your version resource with PyTorch 1.3.1 for Python 3.7, built to run on a CPU in Linux.\n",
    "\n",
    "Use the following command to create your version resource:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "WbE2cKVE1PaX"
   },
   "outputs": [],
   "source": [
    "!gcloud beta ai-platform versions create {MODEL_VERSION} --model {MODEL_NAME} \\\n",
    " --origin=gs://{BUCKET_NAME}/{MODEL_DIR}/ \\\n",
    " --python-version=3.7 \\\n",
    " --runtime-version={RUNTIME_VERSION} \\\n",
    " --package-uris=gs://{BUCKET_NAME}/{PACKAGES_DIR}/text_classification-0.1.tar.gz,gs://cloud-ai-pytorch/torch-1.3.1+cpu-cp37-cp37m-linux_x86_64.whl \\\n",
    " --machine-type=mls1-c4-m4 \\\n",
    " --prediction-class=model.CustomModelPrediction"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Jn6EFbPUzUTm"
   },
   "source": [
    "## Online Predictions from AI Platform Prediction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "mNFvql_-zC4N",
    "outputId": "38101952-d6e7-4714-9366-705535670a7c"
   },
   "outputs": [],
   "source": [
    "from googleapiclient import discovery\n",
    "from oauth2client.client import GoogleCredentials\n",
    "import json\n",
    "\n",
    "# JSON format the requests\n",
    "request_data = {'instances': requests}\n",
    "\n",
    "# Authenticate and call CMLE prediction API \n",
    "credentials = GoogleCredentials.get_application_default()\n",
    "api = discovery.build('ml', 'v1', credentials=credentials)\n",
    "\n",
    "parent = 'projects/{}/models/{}/versions/{}'.format(PROJECT_ID, MODEL_NAME, MODEL_VERSION)\n",
    "print(\"Model full name: {}\".format(parent))\n",
    "response = api.projects().predict(body=request_data, name=parent).execute()\n",
    "\n",
    "print(response['predictions'])"
   ]
  },
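  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, the request sent above boils down to two pieces: a fully qualified version name and a JSON body with an `instances` list. The sketch below builds both with the standard library only, so you can inspect them without calling the service. The project ID is a placeholder assumption; substitute your own values.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "# Placeholder identifiers; substitute your own project, model and version.\n",
    "project_id = 'my-project'\n",
    "model_name = 'torch_text_classification'\n",
    "model_version = 'v1'\n",
    "\n",
    "name = 'projects/{}/models/{}/versions/{}'.format(\n",
    "    project_id, model_name, model_version)\n",
    "body = json.dumps({'instances': ['Uber shuts down self-driving trucks unit']})\n",
    "\n",
    "print(name)\n",
    "print(body)\n",
    "```\n",
    "\n",
    "Each element of `instances` is one raw headline; the server-side `predict` method receives this list and handles the pre-processing."
   ]
  },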
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Cleaning up\n",
    "\n",
    "To clean up all GCP resources used in this project, you can [delete the GCP\n",
    "project](https://cloud.google.com/resource-manager/docs/creating-managing-projects#shutting_down_projects) you used for the tutorial.\n",
    "\n",
    "Alternatively, you can clean up individual resources by running the following\n",
    "commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Delete model version resource\n",
    "!gcloud ai-platform versions delete {MODEL_VERSION} --model {MODEL_NAME} --quiet\n",
    "\n",
    "# Delete model resource\n",
    "! gcloud ai-platform models delete {MODEL_NAME} --quiet\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "Pm3d3buc0Bi1"
   },
   "source": [
    "## Authors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CA7T5kIZ0Bsq"
   },
   "source": [
    "Khalid Salama & Vijay Reddy \n",
    "\n",
    "**Disclaimer**: This is not an official Google product. The sample code provided for an educational purpose.\n"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "Text Classification Using PyTorch and CMLE.ipynb",
   "provenance": [],
   "toc_visible": true,
   "version": "0.3.2"
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
